{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "origin_pos": 0
   },
   "source": [
    "# The Object Detection Dataset\n",
    "\n",
    "Object detection has no small, standard datasets comparable to MNIST or Fashion-MNIST. To test models quickly, we assemble a small dataset of our own. First, we generate 1000 banana images at different angles and sizes using free bananas from our office. Then, we collect a series of background images and place a banana image at a random position on each one.\n",
    "\n",
    "## Downloading the Dataset\n",
    "\n",
    "The banana detection dataset in RecordIO format can be downloaded directly from the Internet.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load ../utils/djl-imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import ai.djl.basicdataset.cv.*;\n",
    "import java.awt.*;\n",
    "import java.awt.image.*;\n",
    "import java.util.List;\n",
    "import javax.swing.*;\n",
    "import ai.djl.modality.cv.output.Rectangle;"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "origin_pos": 2,
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "## Reading the Dataset\n",
    "\n",
    "We read the object detection dataset by creating an instance of `BananaDetection`. DJL makes it fairly easy to get the dataset. Here is how we do it."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "// Load the bananas dataset.\n",
    "BananaDetection trainIter = BananaDetection.builder()\n",
    "        .setSampling(32, true)  // Read the dataset in random order\n",
    "        .optUsage(Dataset.Usage.TRAIN)\n",
    "        .build();\n",
    "\n",
    "trainIter.prepare();"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "origin_pos": 4,
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "Below, we read a minibatch and print the shapes of the image and label tensors. The image shape is the same as in the previous experiment: (batch size, number of channels, height, width). The label shape is (batch size, $m$, 5), where $m$ is the maximum number of bounding boxes contained in any single image in the dataset.\n",
    "\n",
    "Minibatch computation is efficient, but it requires every image to carry the same number of bounding boxes so that the labels can be stacked into one tensor. Since images may contain different numbers of objects, images with fewer than $m$ bounding boxes are padded with invalid bounding boxes until each has exactly $m$.\n",
    "\n",
    "The label of each bounding box is an array of length 5. The first element is the category of the object in the box; a value of -1 marks an invalid bounding box that exists only for padding. The remaining four elements are the $x, y$ coordinates of the upper-left corner followed by the $x, y$ coordinates of the lower-right corner, each scaled to the range 0 to 1. The banana dataset has only one bounding box per image, so $m=1$.\n"
   ]
  },
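  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The padding scheme described above can be sketched in plain Java, without DJL. This is a minimal illustration under stated assumptions, not how `BananaDetection` builds its labels: the class `LabelPadding` and the method `padLabels` are hypothetical names, and filling the coordinates of a padding row with -1 (rather than only the class entry) is a choice made here for simplicity.\n",
    "\n",
    "```java\n",
    "import java.util.Arrays;\n",
    "\n",
    "public class LabelPadding {\n",
    "    /** Pads rows of [class, xMin, yMin, xMax, yMax] to exactly m rows. */\n",
    "    public static float[][] padLabels(float[][] boxes, int m) {\n",
    "        if (boxes.length > m) {\n",
    "            throw new IllegalArgumentException(\"more boxes than m\");\n",
    "        }\n",
    "        float[][] padded = new float[m][5];\n",
    "        for (int i = 0; i < m; i++) {\n",
    "            if (i < boxes.length) {\n",
    "                padded[i] = Arrays.copyOf(boxes[i], 5);  // real box, copied as-is\n",
    "            } else {\n",
    "                Arrays.fill(padded[i], -1f);  // class -1 marks a padding row (assumption: coordinates are -1 too)\n",
    "            }\n",
    "        }\n",
    "        return padded;\n",
    "    }\n",
    "\n",
    "    public static void main(String[] args) {\n",
    "        // One real box padded to m = 3 rows; the last two rows are padding.\n",
    "        float[][] oneBox = {{0f, 0.1f, 0.2f, 0.5f, 0.6f}};\n",
    "        System.out.println(Arrays.deepToString(padLabels(oneBox, 3)));\n",
    "    }\n",
    "}\n",
    "```\n",
    "\n",
    "Every image in the banana dataset has exactly one box, so with $m=1$ no padding rows are needed there."
   ]
  },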
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "attributes": {
     "classes": [],
     "id": "",
     "n": "4"
    },
    "origin_pos": 5,
    "pycharm": {
     "name": "#%%\n"
    },
    "tab": [
     "mxnet"
    ]
   },
   "outputs": [],
   "source": [
    "NDManager manager = NDManager.newBaseManager();\n",
    "\n",
    "Batch batch = trainIter.getData(manager).iterator().next();\n",
    "System.out.println(batch.getData().get(0).getShape() + \", \" + batch.getLabels().get(0).getShape());"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "origin_pos": 6,
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "## Demonstration\n",
    "\n",
    "Below, we display ten images with their bounding boxes. The angle, size, and position of the banana differ in each image. Of course, this is a simple artificial dataset; in practice, the data are usually much more complicated.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "public class ImagePanel extends JPanel {\n",
    "    int SCALE;\n",
    "    BufferedImage img;\n",
    "\n",
    "    public ImagePanel() {\n",
    "        this.SCALE = 1;\n",
    "    }\n",
    "\n",
    "    public ImagePanel(int scale, BufferedImage img) {\n",
    "        this.SCALE = scale;\n",
    "        this.img = img;\n",
    "    }\n",
    "\n",
    "    @Override\n",
    "    protected void paintComponent(Graphics g) {\n",
    "        Graphics2D g2d = (Graphics2D) g;\n",
    "        g2d.scale(SCALE, SCALE);\n",
    "        g2d.drawImage(this.img, 0, 0, this);\n",
    "    }\n",
    "}\n",
    "\n",
    "public class Container extends JPanel {\n",
    "    public Container(String label) {\n",
    "        setLayout(new BoxLayout(this, BoxLayout.Y_AXIS));\n",
    "        JLabel l = new JLabel(label, JLabel.CENTER);\n",
    "        l.setAlignmentX(Component.CENTER_ALIGNMENT);\n",
    "        add(l);\n",
    "    }\n",
    "\n",
    "    public Container(String trueLabel, String predLabel) {\n",
    "        setLayout(new BoxLayout(this, BoxLayout.Y_AXIS));\n",
    "        JLabel l = new JLabel(trueLabel, JLabel.CENTER);\n",
    "        l.setAlignmentX(Component.CENTER_ALIGNMENT);\n",
    "        add(l);\n",
    "        JLabel l2 = new JLabel(predLabel, JLabel.CENTER);\n",
    "        l2.setAlignmentX(Component.CENTER_ALIGNMENT);\n",
    "        add(l2);\n",
    "    }\n",
    "}\n",
    "    \n",
    "public static void showImages(Image[] dataset,\n",
    "                                  int number, int WIDTH, int HEIGHT, int SCALE,\n",
    "                                  NDManager manager)\n",
    "            throws IOException, TranslateException {\n",
    "    // Plot a list of images\n",
    "    JFrame frame = new JFrame(\"\");\n",
    "    for (int record = 0; record < number; record++) {\n",
    "        Image i = dataset[record];\n",
    "        BufferedImage img = (BufferedImage) i.getWrappedImage();\n",
    "        Graphics2D g = (Graphics2D) img.getGraphics();\n",
    "\n",
    "        JPanel panel = new ImagePanel(SCALE, img);\n",
    "        panel.setPreferredSize(new Dimension(WIDTH * SCALE, HEIGHT * SCALE));\n",
    "        JPanel container = new Container(\"\");\n",
    "        container.add(panel);\n",
    "        frame.getContentPane().add(container);\n",
    "    }\n",
    "    frame.getContentPane().setLayout(new FlowLayout());\n",
    "    frame.pack();\n",
    "    frame.setVisible(true);\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "attributes": {
     "classes": [],
     "id": "",
     "n": "5"
    },
    "origin_pos": 7,
    "tab": [
     "mxnet"
    ]
   },
   "outputs": [],
   "source": [
    "Image[] imageArr = new Image[10];\n",
    "List<List<String>> classNames = new ArrayList<>();\n",
    "List<List<Double>> prob = new ArrayList<>();\n",
    "List<List<BoundingBox>> boxes = new ArrayList<>();\n",
    "\n",
    "Batch batch = trainIter.getData(manager).iterator().next();\n",
    "for (int i = 0; i < 10; i++) {\n",
    "    NDArray imgData = batch.getData().get(0).get(i);\n",
    "    imgData.muli(255);  // Scale pixel values by 255 in place\n",
    "    NDArray imgLabel = batch.getLabels().get(0).get(i);\n",
    "\n",
    "    // Every image holds a single banana, drawn with probability 1.0\n",
    "    List<String> bananaList = new ArrayList<>();\n",
    "    bananaList.add(\"banana\");\n",
    "    classNames.add(new ArrayList<>(bananaList));\n",
    "\n",
    "    List<Double> probabilityList = new ArrayList<>();\n",
    "    probabilityList.add(1.0);\n",
    "    prob.add(new ArrayList<>(probabilityList));\n",
    "\n",
    "    List<BoundingBox> boundBoxes = new ArrayList<>();\n",
    "\n",
    "    // Label layout: [class, xMin, yMin, xMax, yMax]\n",
    "    float[] coord = imgLabel.get(0).toFloatArray();\n",
    "    double xMin = (double) (coord[1]);\n",
    "    double yMin = (double) (coord[2]);\n",
    "    double xMax = (double) (coord[3]);\n",
    "    double yMax = (double) (coord[4]);\n",
    "\n",
    "    // Rectangle takes (x, y, width, height), so convert from corner format\n",
    "    boundBoxes.add(new Rectangle(xMin, yMin, (xMax - xMin), (yMax - yMin)));\n",
    "\n",
    "    boxes.add(new ArrayList<>(boundBoxes));\n",
    "    DetectedObjects detectedObjects = new DetectedObjects(classNames.get(i), prob.get(i), boxes.get(i));\n",
    "    imageArr[i] = ImageFactory.getInstance().fromNDArray(imgData.toType(DataType.INT8, true));\n",
    "    imageArr[i].drawBoundingBoxes(detectedObjects);\n",
    "}\n",
    "\n",
    "// refer to https://github.com/deepjavalibrary/d2l-java/tree/master/documentation/troubleshoot.md \n",
    "// if you encounter X11 errors when drawing bounding boxes.\n",
    "showImages(imageArr, 10, 256, 256, 1, manager)"
   ]
  },
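  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The loop above converts each label from corner format (upper-left and lower-right corners) into the (x, y, width, height) form passed to `Rectangle`. The arithmetic can be sketched on its own in plain Java. The class `BoxConversion` and its methods are hypothetical helpers written for illustration, and the 256 by 256 image size is taken from the `showImages` call above.\n",
    "\n",
    "```java\n",
    "import java.util.Arrays;\n",
    "\n",
    "public class BoxConversion {\n",
    "    /** Converts corner format (xMin, yMin, xMax, yMax) to [x, y, width, height]. */\n",
    "    public static double[] cornerToRect(double xMin, double yMin, double xMax, double yMax) {\n",
    "        return new double[] {xMin, yMin, xMax - xMin, yMax - yMin};\n",
    "    }\n",
    "\n",
    "    /** Scales a normalized [x, y, width, height] box to pixel units. */\n",
    "    public static double[] toPixels(double[] rect, int imageWidth, int imageHeight) {\n",
    "        return new double[] {\n",
    "            rect[0] * imageWidth, rect[1] * imageHeight,\n",
    "            rect[2] * imageWidth, rect[3] * imageHeight\n",
    "        };\n",
    "    }\n",
    "\n",
    "    public static void main(String[] args) {\n",
    "        double[] rect = cornerToRect(0.25, 0.5, 0.75, 1.0);\n",
    "        double[] px = toPixels(rect, 256, 256);\n",
    "        System.out.println(Arrays.toString(px));  // prints [64.0, 128.0, 128.0, 128.0]\n",
    "    }\n",
    "}\n",
    "```\n",
    "\n",
    "Keeping the coordinates normalized until the last step means the same label stays valid if the image is resized."
   ]
  },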
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "![Banana images with bounding boxes.](https://d2l-java-resources.s3.amazonaws.com/img/object_detection.png)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "origin_pos": 8,
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "## Summary\n",
    "\n",
    "* The banana detection dataset we synthesized can be used to test object detection models.\n",
    "* Reading data for object detection is similar to reading data for image classification. However, once bounding boxes are introduced, the label shape and the image augmentation (e.g., random cropping) both change.\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Java",
   "language": "java",
   "name": "java"
  },
  "language_info": {
   "codemirror_mode": "java",
   "file_extension": ".jshell",
   "mimetype": "text/x-java-source",
   "name": "Java",
   "pygments_lexer": "java",
   "version": "14.0.2+12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
